Gemma4 HeteroTP support for NIXL KV connector#41169
Conversation
|
This pull request has merge conflicts that must be resolved before it can be |
There was a problem hiding this comment.
Code Review
This pull request refactors the NIXL connector to use a plan-based transfer architecture, introducing EngineTransferPlan to handle complex transfer geometries for hybrid SSM+Attention and HeteroTP models. While this architecture simplifies the transfer hot path, several critical bugs were identified: the splitting logic incorrectly assumes uniform parameters across all attention groups, and an incorrect boundary parameter breaks Mamba state transfers. Additionally, the block trimming logic must be updated to avoid corrupting Mamba state, and plan generation needs to be more robust when encountering homogeneous remote engines to prevent assertion failures.
| remote_tpb = tuple(nixl_agent_meta.tokens_per_block_per_group or ()) | ||
| assert len(remote_tpb) == len(self._group_kinds), ( | ||
| f"Remote tokens_per_block_per_group length " | ||
| f"{len(remote_tpb)} != {len(self._group_kinds)} groups" | ||
| ) |
There was a problem hiding this comment.
The assertion len(remote_tpb) == len(self._group_kinds) will fail when connecting to a remote engine that is homogeneous (e.g., an older vLLM version or a different model configuration). In such cases, nixl_agent_meta.tokens_per_block_per_group is None, causing remote_tpb to be an empty tuple, while self._group_kinds is always populated. For homogeneous remotes, remote_tpb should default to a tuple of nixl_agent_meta.block_size repeated for each group.
| remote_tpb = tuple(nixl_agent_meta.tokens_per_block_per_group or ()) | |
| assert len(remote_tpb) == len(self._group_kinds), ( | |
| f"Remote tokens_per_block_per_group length " | |
| f"{len(remote_tpb)} != {len(self._group_kinds)} groups" | |
| ) | |
| remote_tpb = tuple(nixl_agent_meta.tokens_per_block_per_group) if nixl_agent_meta.tokens_per_block_per_group is not None else (nixl_agent_meta.block_size,) * len(self._group_kinds) | |
| assert len(remote_tpb) == len(self._group_kinds), ( | |
| f"Remote tokens_per_block_per_group length " | |
| f"{len(remote_tpb)} != {len(self._group_kinds)} groups" | |
| ) |
| fa_slot = fa_slot_map.get(p_rank, 0) | ||
|
|
||
| handle: list[tuple[int, int, int]] = [] | ||
| for j, (addr, local_len, dev) in enumerate(src_blocks_data): | ||
| if j < num_fa_descs: | ||
| chunk = local_len // fa_num_splits |
There was a problem hiding this comment.
The splitting logic in build_local_splits_from_plan incorrectly assumes that all attention groups share the same split count and slot mapping as the first group (plan.source_ranks_per_group[0] and plan.rank_to_attention_slot[0]). In HeteroTP models like Gemma4, different attention groups (e.g., SWA vs. FA) can have different numbers of KV heads, which may result in different numbers of remote source ranks (due to GQA deduplication) and different head-to-slot assignments. Using group 0's parameters for all attention descriptors will lead to incorrect memory slicing and data corruption for other groups.
| for handle_data in build_local_splits_from_plan( | ||
| plan, | ||
| self.src_blocks_data, | ||
| self.num_descs, | ||
| ): |
There was a problem hiding this comment.
Passing self.num_descs as the num_fa_descs boundary to build_local_splits_from_plan is incorrect if self.num_descs represents the total number of descriptors (FA + SSM). This causes the logic in transfer_plan.py to treat all descriptors as attention descriptors (since j < num_fa_descs will always be true), and has_ssm_descs will be incorrectly evaluated as False. This will break the 3-read transfer logic for Mamba state.
| # Partial prefix cache hit: trim to the shorter of local/remote. | ||
| # After ReadSpec construction, local descriptor IDs and remote | ||
| # block IDs should already have matched lengths per group | ||
| # (gather-read pairing ensures this). Trim from the head to | ||
| # keep the tail (newest blocks). | ||
| remote_block_ids = list(remote_block_ids) | ||
| for i, remote_group in enumerate(remote_block_ids): | ||
| num_remote_blocks = len(remote_group) | ||
| num_local_blocks = len(local_block_ids[i]) | ||
| if not self._is_mamba_group[i]: | ||
| assert num_local_blocks <= num_remote_blocks | ||
| # Partial prefix cache hit: just read uncomputed blocks. | ||
| # Skip mamba groups — their blocks represent full state (conv+ssm), | ||
| # not per-token data, so trimming would corrupt the transfer. | ||
| if num_local_blocks < num_remote_blocks and not self._is_mamba_group[i]: | ||
| remote_block_ids[i] = remote_group[-num_local_blocks:] | ||
| local_block_ids = list(local_block_ids) | ||
| for i in range(len(remote_block_ids)): | ||
| n_local = len(local_block_ids[i]) | ||
| n_remote = len(remote_block_ids[i]) | ||
| n = min(n_local, n_remote) | ||
| if n_local > n: | ||
| local_block_ids[i] = local_block_ids[i][-n:] | ||
| if n_remote > n: | ||
| remote_block_ids[i] = remote_block_ids[i][-n:] |
There was a problem hiding this comment.
The new trimming logic in _read_blocks removes the Mamba-specific check that previously prevented trimming for Mamba groups. Mamba blocks represent full state (conv + SSM) rather than incremental per-token data. Trimming these blocks based on a length mismatch between local and remote requests is likely to result in state corruption. The logic should continue to skip trimming for groups where is_ssm is true.
| # Partial prefix cache hit: trim to the shorter of local/remote. | |
| # After ReadSpec construction, local descriptor IDs and remote | |
| # block IDs should already have matched lengths per group | |
| # (gather-read pairing ensures this). Trim from the head to | |
| # keep the tail (newest blocks). | |
| remote_block_ids = list(remote_block_ids) | |
| for i, remote_group in enumerate(remote_block_ids): | |
| num_remote_blocks = len(remote_group) | |
| num_local_blocks = len(local_block_ids[i]) | |
| if not self._is_mamba_group[i]: | |
| assert num_local_blocks <= num_remote_blocks | |
| # Partial prefix cache hit: just read uncomputed blocks. | |
| # Skip mamba groups — their blocks represent full state (conv+ssm), | |
| # not per-token data, so trimming would corrupt the transfer. | |
| if num_local_blocks < num_remote_blocks and not self._is_mamba_group[i]: | |
| remote_block_ids[i] = remote_group[-num_local_blocks:] | |
| local_block_ids = list(local_block_ids) | |
| for i in range(len(remote_block_ids)): | |
| n_local = len(local_block_ids[i]) | |
| n_remote = len(remote_block_ids[i]) | |
| n = min(n_local, n_remote) | |
| if n_local > n: | |
| local_block_ids[i] = local_block_ids[i][-n:] | |
| if n_remote > n: | |
| remote_block_ids[i] = remote_block_ids[i][-n:] | |
| for i in range(len(remote_block_ids)): | |
| if self._group_kinds[i].is_ssm: | |
| continue | |
| n_local = len(local_block_ids[i]) | |
| n_remote = len(remote_block_ids[i]) | |
| n = min(n_local, n_remote) | |
| if n_local > n: | |
| local_block_ids[i] = local_block_ids[i][-n:] | |
| if n_remote > n: | |
| remote_block_ids[i] = remote_block_ids[i][-n:] |
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…delBlockTransferPolicy Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…cal dict Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
… and remote_tp_size Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…fo, remove defaults, regroup args Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…mote and physical_blocks_per_logical, regroup Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…npack read spec loop Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…omputation Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…blocks_per_logical removal Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…_policy and use_mla, set num_descs Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…nsfer_info Add physical_blocks_per_logical to TransferTopology and pass transfer_topo directly to build_engine_transfer_info, reducing the method's parameter count from 10 to 6. Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…11→4) Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…_id (7→4) Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…to_kernel_block_ids Move model-specific block ID expansion and trimming logic out of worker.py hot paths into a unified logical_to_kernel_block_ids() in transfer_plan.py. Per-request functions now consume only the pre-computed EngineTransferPlan (model-agnostic); model awareness is confined to init and plan generation (if/else dense vs mamba). - Add transfer_plan.py with plan generators, executors, and logical_to_kernel_block_ids (per-group physical_per_logical). - Generate EngineTransferPlan during handshake (generate_dense_plan or generate_mamba_plan), stored in _transfer_plans dict. - Replace _logical_to_remote_kernel_block_ids in _read_blocks_for_req with plan-based logical_to_kernel_block_ids (no model branching). - Make block trimming in _read_blocks unconditional (SSM groups are no-op due to shared block table). - Thin-wrap _logical_to_kernel_block_ids for local expansion, delegating to the same unified function. - Add _conv_decomp to worker init for mamba plan generation. Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
The unified logical_to_kernel_block_ids had two bugs: 1. Dense remote: used ratio=1 instead of local kernel ratio 2. Mamba FA remote: used same value for stride and count, but old code used remote_ratio as stride, local_ratio as count Restore the original two-method design: - _logical_to_kernel_block_ids: local expansion (same as main) - _logical_to_remote_kernel_block_ids: remote mamba expansion (same as main) Only difference from main: remote_ratio comes from plan.remote_physical_blocks_per_logical instead of self._mamba_phys_ratio[engine_id]. Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…ake field Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
…nsion Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Per-group TP mapping, sub-descriptor splitting, and block ID remapping for heterogeneous attention models (SWA + FA groups with different total_num_kv_heads). Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
When D_TP < P_TP (e.g. 4p2d), local pages are larger than remote pages. Introduce gather-read: split local blocks into sub-descriptors matching the remote page size so NIXL can pair them for RDMA. - Add local_to_remote_page_ratio and remote_blocks_per_local_block to EngineTransferPlan; generate_gemma4_plan now handles both split-read (2p4d) and gather-read (4p2d) directions - Add build_fa_local_descs_for_gather_read for local sub-desc registration - Worker: register gather-read handles, bidirectional block trimming, scale local desc-space block count for gather-read - Fix FullAttentionSpec.merge() and SinkFullAttentionSpec.merge() to propagate total_num_kv_heads (was causing AssertionError on Gemma4) - Add unit tests for gather-read configs (4p2d, 4p1d) Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
- Rename sub_desc_index_per_group → remote_desc_offset_per_group and _remap_remote_blocks_to_subdesc_ids → _remap_remote_blocks_to_desc_ids to eliminate confusing "sub-descriptor" naming throughout. - Add _pair_gather_group helper with remainder handling for partial block fills in gather-read (FA groups). - Add length-match assertions in both _build_gather_read_specs and _read_blocks to catch descriptor/block ID pairing mismatches early. - Add diagnostic INFO/DEBUG logging for HeteroTP plan generation, ReadSpec construction, and _read_blocks transfer details. Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
ca3b517 to
56d12e7
Compare
|
This pull request has merge conflicts that must be resolved before it can be |
Depends on #40731